package com.rtsapp.server.network.protocol.crypto.client;

import com.rtsapp.server.logger.Logger;
import com.rtsapp.server.logger.LoggerFactory;
import com.rtsapp.server.network.protocol.crypto.IAppChannelInitializer;
import com.rtsapp.server.network.protocol.crypto.RC4InHandler;
import com.rtsapp.server.network.protocol.crypto.RC4OutHandler;
import com.rtsapp.server.network.protocol.crypto.KeyUtils;
import com.rtsapp.server.network.protocol.crypto.cmd.IRSAHandshakeStartCMD;
import com.rtsapp.server.network.protocol.crypto.cmd.RSAHandshakeCMDType;
import com.rtsapp.server.network.protocol.crypto.cmd.RSAHandshakeCompeleteCMD;
import com.rtsapp.server.network.protocol.crypto.cmd.RSAHandshakeStartCMD;
import com.rtsapp.server.utils.cypto.RSA;
import io.netty.buffer.ByteBuf;
import io.netty.channel.*;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;

import javax.crypto.NoSuchPaddingException;
import java.security.NoSuchAlgorithmException;
import java.security.spec.InvalidKeySpecException;

/**
 * Client side of the RSA connection handshake.
 * <p>
 * Flow implemented by this class:
 * <ul>
 *   <li>{@link #start()}: when RSA key material was supplied, install the handshake handlers and
 *       wait for a handshake command; otherwise install the application's original handlers
 *       directly.</li>
 *   <li>START command: the two communication keys carried by the command are decrypted with the
 *       client's private key, re-encrypted with the server's public key and sent back to the
 *       server.</li>
 *   <li>COMPLETE command: every handler is removed from the pipeline and the original handlers
 *       are installed again, fronted by RC4 handlers keyed with the negotiated keys when the
 *       command requests encryption.</li>
 * </ul>
 * The START round trip is used to verify that each side holds the other side's public key.
 */
public class ClientRSAHandshake {

    private static final Logger LOG = LoggerFactory.getLogger(ClientRSAHandshake.class);

    // Length limits for inbound handshake frames.
    /** Maximum size, in bytes, of one inbound handshake frame. */
    private static final int MAX_INPUT_FRAME_LENGTH = 1024;
    /** Offset of the length field inside a frame. */
    private static final int LENGTH_FIELD_OFFSET = 0;
    /** Size, in bytes, of the length field. */
    private static final int LENGTH_FIELD_LENGTH = 4;
    /** Size, in bytes, of the command id read at the reader index of each handshake message. */
    private static final int COMMAND_ID_LENGTH = Integer.BYTES;

    /** The channel being negotiated. */
    private final SocketChannel channel;
    /** Initializer that installs the application's own handlers once the handshake is over. */
    private final IAppChannelInitializer originChannelInitializer;
    /** RSA key material; {@code null} means the connection runs without encryption. */
    private final ClientRSAHandshakeKey keyRSA;
    /** Client private key (from the N/E/D_CLIENT values); {@code null} when not encrypting. */
    private final RSA clientPrivateRSA;
    /** Server public key (from the N/E_SERVER values); {@code null} when not encrypting. */
    private final RSA serverPublicRSA;

    /** Negotiated key handed to {@link RC4InHandler}; set when a START command is received. */
    private byte[] serverCryptKey;
    /** Negotiated key handed to {@link RC4OutHandler}; set when a START command is received. */
    private byte[] clientCryptKey;

    /** Current handshake state; each value names the next step the connection will perform. */
    private HandshakeState state = HandshakeState.STATE_INIT_RSA_HANDERS;

    /**
     * Creates the handshake for {@code channel}.
     *
     * @param channel                  channel to negotiate
     * @param originChannelInitializer installs the application handlers after the handshake
     * @param keyRSA                   RSA key material, or {@code null} to skip encryption
     * @throws NoSuchPaddingException   propagated from {@link RSA} key initialisation
     * @throws NoSuchAlgorithmException propagated from {@link RSA} key initialisation
     * @throws InvalidKeySpecException  propagated from {@link RSA} key initialisation
     */
    public ClientRSAHandshake(SocketChannel channel, IAppChannelInitializer originChannelInitializer, ClientRSAHandshakeKey keyRSA) throws NoSuchPaddingException, NoSuchAlgorithmException, InvalidKeySpecException {
        this.channel = channel;
        this.originChannelInitializer = originChannelInitializer;
        this.keyRSA = keyRSA;

        if (keyRSA != null) {
            clientPrivateRSA = new RSA();
            clientPrivateRSA.initPrivateKey(keyRSA.getN_CLIENT(), keyRSA.getE_CLIENT(), keyRSA.getD_CLIENT());

            serverPublicRSA = new RSA();
            serverPublicRSA.initPublicKey(keyRSA.getN_SERVER(), keyRSA.getE_SERVER());
        } else {
            clientPrivateRSA = null;
            serverPublicRSA = null;
        }
    }

    /**
     * Starts the handshake.
     * <p>
     * With RSA key material the handshake handlers are installed and the handshake waits for a
     * command; without it the original handlers are installed immediately (nothing is sent).
     *
     * @throws Exception propagated from the original channel initializer
     */
    public void start() throws Exception {
        if (isCrypt()) {
            initRSAHandlers();
        } else {
            initOriginHandlers();
        }
    }

    /**
     * Installs the handshake handlers and moves to {@code STATE_WAIT_RECV_CMD}.
     * Any failure closes the channel.
     */
    public void initRSAHandlers() {
        try {
            ChannelPipeline pipeline = channel.pipeline();
            // Placed at the head so inbound bytes are split into length-prefixed frames before
            // they reach the handshake handler.
            pipeline.addFirst(new LengthFieldBasedFrameDecoder(MAX_INPUT_FRAME_LENGTH, LENGTH_FIELD_OFFSET, LENGTH_FIELD_LENGTH));
            pipeline.addLast(new ClientRSAHandshakeHandler(this));

            changeState(HandshakeState.STATE_WAIT_RECV_CMD);
        } catch (Throwable ex) {
            closeUnexpectedly(ex, "initRSAHandlers");
        }
    }

    /**
     * Dispatches one inbound handshake frame by its command id (peeked, not consumed).
     * Closes the channel when not waiting for a command, when the frame is too short to hold a
     * command id, or when the id is neither START nor COMPLETE.
     *
     * @param buffer the received frame
     */
    public void recvData(ByteBuf buffer) {
        if (state != HandshakeState.STATE_WAIT_RECV_CMD) {
            closeUnexpectedly("recvData error:  state != STATE_WAIT_RECV_CMD");
            return;
        }

        // Guard the peek below: getInt would otherwise throw on an undersized frame.
        if (buffer.readableBytes() < COMMAND_ID_LENGTH) {
            closeUnexpectedly("recvData error: frame too short for commandId");
            return;
        }

        // NOTE(review): the frame decoder is built without initialBytesToStrip, so frames still
        // begin with the 4-byte length field unless ClientRSAHandshakeHandler skips it — confirm
        // the command id really sits at readerIndex.
        int commandId = buffer.getInt(buffer.readerIndex());

        if (commandId == RSAHandshakeCMDType.START) {
            recvStartCMD(buffer);
        } else if (commandId == RSAHandshakeCMDType.COMPLETE) {
            recvCompleteCMD(buffer);
        } else {
            closeUnexpectedly("recvData error: commandId != START && commandId != COMPLETE");
        }
    }

    /**
     * Handles the START command: recovers both communication keys with the client's private key,
     * re-encrypts them with the server's public key and sends the command back to the server.
     * Any failure closes the channel.
     *
     * @param buffer frame holding the START command
     */
    public void recvStartCMD(ByteBuf buffer) {
        try {
            RSAHandshakeStartCMD cmd = new RSAHandshakeStartCMD();
            cmd.readFromBuffer(buffer);

            // Decrypt the negotiated keys with the client's private key and keep them for RC4.
            serverCryptKey = clientPrivateRSA.decrypt(cmd.getServerCryptKey());
            clientCryptKey = clientPrivateRSA.decrypt(cmd.getClientCryptKey());

            // Re-encrypt with the server's public key and echo the command back to the server.
            cmd.setServerCryptKey(serverPublicRSA.encrypt(serverCryptKey));
            cmd.setClientCryptKey(serverPublicRSA.encrypt(clientCryptKey));

            sendData(cmd);
        } catch (Throwable ex) {
            closeUnexpectedly(ex, "recvStartCMD error");
        }
    }

    /**
     * Handles the COMPLETE command and ends the handshake: the pipeline is rebuilt with the
     * original handlers, fronted by RC4 handlers when the command requests encryption.
     * Any failure closes the channel.
     *
     * @param buffer frame holding the COMPLETE command
     */
    public void recvCompleteCMD(ByteBuf buffer) {
        try {
            RSAHandshakeCompeleteCMD cmd = new RSAHandshakeCompeleteCMD();
            cmd.readFromBuffer(buffer);

            if (cmd.isCrypt()) {
                initRC4Handlers();
            } else {
                // The handshake handlers are still installed at this point, so they must be
                // removed here too; otherwise later traffic would still reach recvData().
                rebuildPipeline(false);
            }
        } catch (Throwable ex) {
            closeUnexpectedly(ex, "recvCompleteCMD error");
        }
    }

    /**
     * Replaces every pipeline handler with the original handlers fronted by RC4 handlers keyed
     * with the negotiated keys, then moves to {@code STATE_EXIT}. Closes the channel if no keys
     * have been negotiated yet or if anything fails.
     */
    public void initRC4Handlers() {
        try {
            if (serverCryptKey == null || clientCryptKey == null) {
                closeUnexpectedly("initRC4Handlers error: crypt keys not negotiated");
                return;
            }
            rebuildPipeline(true);
        } catch (Throwable ex) {
            closeUnexpectedly(ex, "initRC4Handlers error");
        }
    }

    /**
     * Removes every handler, installs the original handlers (optionally fronted by RC4 handlers),
     * replays the registration/activation events for the new handlers and exits the handshake.
     *
     * @param withRC4 whether to add the RC4 handlers in front of the original handlers
     * @throws Exception propagated from the original channel initializer
     */
    private void rebuildPipeline(boolean withRC4) throws Exception {
        ChannelPipeline pipeline = channel.pipeline();

        // Drop everything, including the handshake frame decoder and handler.
        while (pipeline.first() != null) {
            pipeline.removeFirst();
        }

        originChannelInitializer.initChannel(channel);

        if (withRC4) {
            // Added at the head so they sit between the socket and the original handlers.
            pipeline.addFirst(new RC4InHandler(serverCryptKey));
            pipeline.addFirst(new RC4OutHandler(clientCryptKey));
        }

        // A handshake command has already been received, so the channel is registered and
        // active; replay those events so the freshly installed handlers observe them.
        pipeline.fireChannelRegistered();
        pipeline.fireChannelActive();

        changeState(HandshakeState.STATE_EXIT);
    }

    /**
     * Installs the original handlers on the unencrypted start path and exits the handshake.
     *
     * @throws Exception propagated from the original channel initializer
     */
    private void initOriginHandlers() throws Exception {
        originChannelInitializer.initChannel(channel);
        changeState(HandshakeState.STATE_EXIT);
    }

    /**
     * Logs {@code message} and closes the channel.
     *
     * @param message description of the failure
     */
    public void closeUnexpectedly(String message) {
        closeUnexpectedly(null, message);
    }

    /**
     * Logs {@code message} with its cause and closes the channel.
     *
     * @param ex      the cause, or {@code null} when there is none
     * @param message description of the failure
     */
    public void closeUnexpectedly(Throwable ex, String message) {
        LOG.error(ex, "[RSAHandshake] closeUnexpectedly, channel={}, message={}", channel, message);
        this.channel.close();
    }

    /** Returns whether RSA key material was supplied, i.e. whether the handshake encrypts. */
    private boolean isCrypt() {
        return keyRSA != null;
    }

    private void changeState(HandshakeState state) {
        this.state = state;
    }

    /**
     * Serialises {@code cmd} and writes it to the channel. Both a synchronous failure and an
     * asynchronous write failure close the channel, so the handshake never hangs waiting for a
     * reply to a message that was never sent.
     *
     * @param cmd command to send
     */
    private void sendData(IRSAHandshakeStartCMD cmd) {
        try {
            ByteBuf buf = IRSAHandshakeStartCMD.outputByteBuf(cmd);

            channel.writeAndFlush(buf).addListener((ChannelFutureListener) future -> {
                if (!future.isSuccess()) {
                    closeUnexpectedly(future.cause(), "sendData error");
                }
            });
        } catch (Throwable ex) {
            closeUnexpectedly(ex, "sendData error");
        }
    }

    /**
     * Handshake state; each value names the next step the connection will perform.
     */
    enum HandshakeState {
        /** Handshake handlers have not been installed yet. */
        STATE_INIT_RSA_HANDERS,
        /** Waiting for a START or COMPLETE command. */
        STATE_WAIT_RECV_CMD,
        /** Handshake finished; the original handlers own the channel. */
        STATE_EXIT;
    }

}
